{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "pytorch",
      "language": "python",
      "name": "pytorch"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.6"
    },
    "colab": {
      "name": "Copy of mini-batch-logistic-regression-evaluator.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "149b9ce8fb68473a837a77431c12281a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_88cd3db2831e4c13a4a634709700d6b2",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_a88c31d74f5c40a2b24bcff5a35d216c",
              "IPY_MODEL_60c6150177694717a622936b830427b5"
            ]
          }
        },
        "88cd3db2831e4c13a4a634709700d6b2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a88c31d74f5c40a2b24bcff5a35d216c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_dba019efadee4fdc8c799f309b9a7e70",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_5901c2829a554c8ebbd5926610088041"
          }
        },
        "60c6150177694717a622936b830427b5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_957362a11d174407979cf17012bf9208",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 2640404480/? [00:51&lt;00:00, 32685718.58it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_a4f82234388e4701a02a9f68a177193a"
          }
        },
        "dba019efadee4fdc8c799f309b9a7e70": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "5901c2829a554c8ebbd5926610088041": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "957362a11d174407979cf17012bf9208": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "a4f82234388e4701a02a9f68a177193a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YUemQib7ZE4D"
      },
      "source": [
        "import torch\n",
        "import sys\n",
        "import numpy as np\n",
        "import os\n",
        "import yaml\n",
        "import matplotlib.pyplot as plt\n",
        "import torchvision"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WSgRE1CcLqdS",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "48a2ae15-f672-495b-8d43-9a23b85fa3b8"
      },
      "source": [
        "# Install gdown, the command-line tool used below to fetch the checkpoint archive from Google Drive.\n",
        "!pip install gdown"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.15.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.23.0)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.41.1)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.12.5)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.10)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NOIJEui1ZziV"
      },
      "source": [
        "def get_file_id_by_model(folder_name):\n",
        "  \"\"\"Look up the Google Drive file id of a pretrained checkpoint archive.\n",
        "\n",
        "  Args:\n",
        "    folder_name: name of the model folder, e.g. 'resnet50_50-epochs_stl10'.\n",
        "\n",
        "  Returns:\n",
        "    The file id string for a known folder name; otherwise the literal\n",
        "    string \"Model not found.\" (no exception is raised).\n",
        "  \"\"\"\n",
        "  known_ids = {\n",
        "      'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',\n",
        "      'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',\n",
        "      'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu',\n",
        "  }\n",
        "  if folder_name in known_ids:\n",
        "    return known_ids[folder_name]\n",
        "  return \"Model not found.\""
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G7YMxsvEZMrX",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "59475430-69d2-45a2-b61b-ae755d5d6e88"
      },
      "source": [
        "# Choose which pretrained model to evaluate and resolve its Google Drive file id.\n",
        "folder_name = 'resnet50_50-epochs_stl10'\n",
        "file_id = get_file_id_by_model(folder_name)\n",
        "# get_file_id_by_model returns this sentinel string for unknown names; stop here\n",
        "# instead of handing a bogus id to the download step.\n",
        "if file_id == \"Model not found.\":\n",
        "  raise ValueError(\"Unknown model folder: {!r}\".format(folder_name))\n",
        "print(folder_name, file_id)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "resnet50_50-epochs_stl10 1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PWZ8fet_YoJm",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "fbaeb858-221b-4d1b-dd90-001a6e713b75"
      },
      "source": [
        "# download and extract model files\n",
        "# os.system returns the command's exit status; a non-zero value means the command\n",
        "# failed, so stop here rather than continuing without the checkpoint files.\n",
        "download_cmd = 'gdown https://drive.google.com/uc?id={}'.format(file_id)\n",
        "if os.system(download_cmd) != 0:\n",
        "  raise RuntimeError('Download failed: ' + download_cmd)\n",
        "\n",
        "# -o overwrites files left by a previous run instead of prompting, so re-running this cell works\n",
        "extract_cmd = 'unzip -o {}'.format(folder_name)\n",
        "if os.system(extract_cmd) != 0:\n",
        "  raise RuntimeError('Extraction failed: ' + extract_cmd)\n",
        "!ls"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "checkpoint_0040.pth.tar\n",
            "config.yml\n",
            "events.out.tfevents.1610927742.4cb2c837708d.2694093.0\n",
            "resnet50_50-epochs_stl10.zip\n",
            "sample_data\n",
            "training.log\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3_nypQVEv-hn"
      },
      "source": [
        "from torch.utils.data import DataLoader\n",
        "import torchvision.transforms as transforms\n",
        "from torchvision import datasets"
      ],
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lDfbL3w_Z0Od",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7532966e-1c4a-4641-c928-4cda14c53389"
      },
      "source": [
        "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
        "if torch.cuda.is_available():\n",
        "  device = 'cuda'\n",
        "else:\n",
        "  device = 'cpu'\n",
        "print(\"Using device:\", device)"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using device: cuda\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BfIPl0G6_RrT"
      },
      "source": [
        "def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
        "  \"\"\"Build DataLoaders over the STL10 'train' and 'test' splits stored under ./data.\n",
        "\n",
        "  Args:\n",
        "    download: forwarded to datasets.STL10 as its download flag.\n",
        "    shuffle: forwarded to both DataLoaders.\n",
        "    batch_size: batch size of the train loader; the test loader uses twice this.\n",
        "\n",
        "  Returns:\n",
        "    A (train_loader, test_loader) tuple.\n",
        "  \"\"\"\n",
        "  root = './data'\n",
        "  train_split = datasets.STL10(root, split='train', download=download,\n",
        "                               transform=transforms.ToTensor())\n",
        "  test_split = datasets.STL10(root, split='test', download=download,\n",
        "                              transform=transforms.ToTensor())\n",
        "\n",
        "  # the train loader loads in the main process (num_workers=0)\n",
        "  train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=shuffle,\n",
        "                            num_workers=0, drop_last=False)\n",
        "  # the test loader uses doubled batches and 10 worker processes\n",
        "  test_loader = DataLoader(test_split, batch_size=2 * batch_size, shuffle=shuffle,\n",
        "                           num_workers=10, drop_last=False)\n",
        "  return train_loader, test_loader\n",
        "\n",
        "def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
        "  \"\"\"Build DataLoaders over the CIFAR10 train and test sets stored under ./data.\n",
        "\n",
        "  Args:\n",
        "    download: forwarded to datasets.CIFAR10 as its download flag.\n",
        "    shuffle: forwarded to both DataLoaders.\n",
        "    batch_size: batch size of the train loader; the test loader uses twice this.\n",
        "\n",
        "  Returns:\n",
        "    A (train_loader, test_loader) tuple.\n",
        "  \"\"\"\n",
        "  root = './data'\n",
        "  train_split = datasets.CIFAR10(root, train=True, download=download,\n",
        "                                 transform=transforms.ToTensor())\n",
        "  test_split = datasets.CIFAR10(root, train=False, download=download,\n",
        "                                transform=transforms.ToTensor())\n",
        "\n",
        "  # the train loader loads in the main process (num_workers=0)\n",
        "  train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=shuffle,\n",
        "                            num_workers=0, drop_last=False)\n",
        "  # the test loader uses doubled batches and 10 worker processes\n",
        "  test_loader = DataLoader(test_split, batch_size=2 * batch_size, shuffle=shuffle,\n",
        "                           num_workers=10, drop_last=False)\n",
        "  return train_loader, test_loader"
      ],
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6N8lYkbmDTaK"
      },
      "source": [
        "# Pass the Loader explicitly: calling yaml.load without one is deprecated since PyYAML 5.1\n",
        "# and raises a TypeError on PyYAML >= 6.\n",
        "# NOTE(review): yaml.Loader can construct arbitrary Python objects, so only load a\n",
        "# config.yml that comes from a trusted source. The cells below read config.arch and\n",
        "# config.dataset_name as attributes, so the file presumably stores a Python object that\n",
        "# yaml.safe_load would refuse to build - confirm before switching loaders.\n",
        "with open('./config.yml') as file:\n",
        "  config = yaml.load(file, Loader=yaml.Loader)"
      ],
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a18lPD-tIle6"
      },
      "source": [
        "# Build a randomly initialised ResNet with 10 output classes for the architecture named in the config.\n",
        "if config.arch == 'resnet18':\n",
        "  model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)\n",
        "elif config.arch == 'resnet50':\n",
        "  model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)\n",
        "else:\n",
        "  # without this branch `model` would stay undefined and a later cell would fail with a NameError\n",
        "  raise ValueError(\"Unsupported architecture: {!r}\".format(config.arch))"
      ],
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4AIfgq41GuTT"
      },
      "source": [
        "# NOTE(review): torch.load unpickles the file, so only open checkpoints from a trusted source.\n",
        "checkpoint = torch.load('checkpoint_0040.pth.tar', map_location=device)\n",
        "state_dict = checkpoint['state_dict']\n",
        "\n",
        "# Keep only the encoder weights: entries under 'backbone.' (except its fc head) are\n",
        "# re-registered without the prefix, and every original key is dropped.\n",
        "prefix = 'backbone.'\n",
        "for key in list(state_dict.keys()):\n",
        "  value = state_dict.pop(key)\n",
        "  if key.startswith(prefix) and not key.startswith('backbone.fc'):\n",
        "    state_dict[key[len(prefix):]] = value"
      ],
      "execution_count": 21,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VVjA83PPJYWl"
      },
      "source": [
        "# strict=False lets the load proceed even though the filtered state_dict has no fc entries;\n",
        "# the assert then checks that fc.weight and fc.bias are the only parameters left unloaded.\n",
        "log = model.load_state_dict(state_dict, strict=False)\n",
        "assert log.missing_keys == ['fc.weight', 'fc.bias']"
      ],
      "execution_count": 22,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_GC0a14uWRr6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 117,
          "referenced_widgets": [
            "149b9ce8fb68473a837a77431c12281a",
            "88cd3db2831e4c13a4a634709700d6b2",
            "a88c31d74f5c40a2b24bcff5a35d216c",
            "60c6150177694717a622936b830427b5",
            "dba019efadee4fdc8c799f309b9a7e70",
            "5901c2829a554c8ebbd5926610088041",
            "957362a11d174407979cf17012bf9208",
            "a4f82234388e4701a02a9f68a177193a"
          ]
        },
        "outputId": "4c2558db-921c-425e-f947-6cc746d8c749"
      },
      "source": [
        "# Build the train/test loaders for the dataset named in the config.\n",
        "if config.dataset_name == 'cifar10':\n",
        "  train_loader, test_loader = get_cifar10_data_loaders(download=True)\n",
        "elif config.dataset_name == 'stl10':\n",
        "  train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
        "else:\n",
        "  # without this branch the loaders would stay undefined and the training cell would fail with a NameError\n",
        "  raise ValueError(\"Unsupported dataset: {!r}\".format(config.dataset_name))\n",
        "print(\"Dataset:\", config.dataset_name)"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "149b9ce8fb68473a837a77431c12281a",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting ./data/stl10_binary.tar.gz to ./data\n",
            "Files already downloaded and verified\n",
            "Dataset: stl10\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pYT_KsM0Mnnr"
      },
      "source": [
        "# freeze all layers but the last fc\n",
        "for name, param in model.named_parameters():\n",
        "    if name not in ['fc.weight', 'fc.bias']:\n",
        "        param.requires_grad = False\n",
        "\n",
        "parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
        "assert len(parameters) == 2  # fc.weight, fc.bias"
      ],
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aPVh1S_eMRDU"
      },
      "source": [
        "# Adam optimizer over the model's parameters and a cross-entropy loss moved to the chosen device.\n",
        "optimizer = torch.optim.Adam(\n",
        "    model.parameters(),\n",
        "    lr=3e-4,\n",
        "    weight_decay=0.0008,\n",
        ")\n",
        "criterion = torch.nn.CrossEntropyLoss()\n",
        "criterion = criterion.to(device)"
      ],
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "edr6RhP2PdVq"
      },
      "source": [
        "def accuracy(output, target, topk=(1,)):\n",
        "    \"\"\"Return the top-k accuracy, in percent, for each k in `topk`.\n",
        "\n",
        "    Args:\n",
        "        output: per-sample class scores; the k best classes are taken along dim 1.\n",
        "        target: true class index of each sample.\n",
        "        topk: iterable of k values to evaluate.\n",
        "\n",
        "    Returns:\n",
        "        A list with one single-element tensor per k, in the order given by `topk`.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        largest_k = max(topk)\n",
        "        n_samples = target.size(0)\n",
        "\n",
        "        # indices of the largest_k best-scoring classes, transposed so row i holds the (i+1)-th guess\n",
        "        _, top_idx = output.topk(largest_k, 1, True, True)\n",
        "        top_idx = top_idx.t()\n",
        "        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))\n",
        "\n",
        "        # a sample counts as correct for k when any of its first k guesses matches the target\n",
        "        scale = 100.0 / n_samples\n",
        "        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)\n",
        "                for k in topk]"
      ],
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qOder0dAMI7X",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5f723b91-5a5e-43eb-ca01-a9b5ae2f1346"
      },
      "source": [
        "# Train the classifier for a fixed number of epochs and report train/test accuracy after each one.\n",
        "epochs = 100\n",
        "for epoch in range(epochs):\n",
        "  # training pass: stay in train mode, as the model was before evaluation switched modes\n",
        "  model.train()\n",
        "  top1_train_accuracy = 0\n",
        "  n_train = 0\n",
        "  for x_batch, y_batch in train_loader:\n",
        "    x_batch = x_batch.to(device)\n",
        "    y_batch = y_batch.to(device)\n",
        "\n",
        "    logits = model(x_batch)\n",
        "    loss = criterion(logits, y_batch)\n",
        "    top1 = accuracy(logits, y_batch, topk=(1,))\n",
        "    # weight each batch's accuracy by its size so a smaller final batch is not over-counted\n",
        "    top1_train_accuracy += top1[0] * y_batch.size(0)\n",
        "    n_train += y_batch.size(0)\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "  top1_train_accuracy /= n_train\n",
        "\n",
        "  # evaluation pass: eval mode makes BatchNorm use its running statistics (and stops it from\n",
        "  # updating them with test data); no_grad skips building the autograd graph\n",
        "  model.eval()\n",
        "  top1_accuracy = 0\n",
        "  top5_accuracy = 0\n",
        "  n_test = 0\n",
        "  with torch.no_grad():\n",
        "    for x_batch, y_batch in test_loader:\n",
        "      x_batch = x_batch.to(device)\n",
        "      y_batch = y_batch.to(device)\n",
        "\n",
        "      logits = model(x_batch)\n",
        "\n",
        "      top1, top5 = accuracy(logits, y_batch, topk=(1,5))\n",
        "      top1_accuracy += top1[0] * y_batch.size(0)\n",
        "      top5_accuracy += top5[0] * y_batch.size(0)\n",
        "      n_test += y_batch.size(0)\n",
        "\n",
        "  top1_accuracy /= n_test\n",
        "  top5_accuracy /= n_test\n",
        "  print(f\"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}\")"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0\tTop1 Train accuracy 28.7109375\tTop1 Test accuracy: 43.75\tTop5 test acc: 93.837890625\n",
            "Epoch 1\tTop1 Train accuracy 49.37959671020508\tTop1 Test accuracy: 52.8662109375\tTop5 test acc: 95.439453125\n",
            "Epoch 2\tTop1 Train accuracy 55.257354736328125\tTop1 Test accuracy: 56.45263671875\tTop5 test acc: 95.91796875\n",
            "Epoch 3\tTop1 Train accuracy 57.51838302612305\tTop1 Test accuracy: 57.39013671875\tTop5 test acc: 96.19384765625\n",
            "Epoch 4\tTop1 Train accuracy 58.727020263671875\tTop1 Test accuracy: 58.2568359375\tTop5 test acc: 96.435546875\n",
            "Epoch 5\tTop1 Train accuracy 59.677162170410156\tTop1 Test accuracy: 58.7353515625\tTop5 test acc: 96.50390625\n",
            "Epoch 6\tTop1 Train accuracy 60.065486907958984\tTop1 Test accuracy: 59.17724609375\tTop5 test acc: 96.708984375\n",
            "Epoch 7\tTop1 Train accuracy 60.612361907958984\tTop1 Test accuracy: 59.482421875\tTop5 test acc: 96.74560546875\n",
            "Epoch 8\tTop1 Train accuracy 60.827205657958984\tTop1 Test accuracy: 59.66064453125\tTop5 test acc: 96.77490234375\n",
            "Epoch 9\tTop1 Train accuracy 61.100643157958984\tTop1 Test accuracy: 60.09521484375\tTop5 test acc: 96.82373046875\n",
            "Epoch 10\tTop1 Train accuracy 61.52803421020508\tTop1 Test accuracy: 60.3466796875\tTop5 test acc: 96.82861328125\n",
            "Epoch 11\tTop1 Train accuracy 61.80147171020508\tTop1 Test accuracy: 60.6640625\tTop5 test acc: 96.8896484375\n",
            "Epoch 12\tTop1 Train accuracy 62.09444046020508\tTop1 Test accuracy: 60.96435546875\tTop5 test acc: 96.99462890625\n",
            "Epoch 13\tTop1 Train accuracy 62.541358947753906\tTop1 Test accuracy: 61.13037109375\tTop5 test acc: 97.0068359375\n",
            "Epoch 14\tTop1 Train accuracy 62.853858947753906\tTop1 Test accuracy: 61.34033203125\tTop5 test acc: 97.01904296875\n",
            "Epoch 15\tTop1 Train accuracy 62.951515197753906\tTop1 Test accuracy: 61.5673828125\tTop5 test acc: 96.99951171875\n",
            "Epoch 16\tTop1 Train accuracy 63.400733947753906\tTop1 Test accuracy: 61.806640625\tTop5 test acc: 97.0361328125\n",
            "Epoch 17\tTop1 Train accuracy 63.66958236694336\tTop1 Test accuracy: 61.98974609375\tTop5 test acc: 97.0849609375\n",
            "Epoch 18\tTop1 Train accuracy 63.82583236694336\tTop1 Test accuracy: 62.265625\tTop5 test acc: 97.07275390625\n",
            "Epoch 19\tTop1 Train accuracy 64.1187973022461\tTop1 Test accuracy: 62.412109375\tTop5 test acc: 97.09716796875\n",
            "Epoch 20\tTop1 Train accuracy 64.2750473022461\tTop1 Test accuracy: 62.56591796875\tTop5 test acc: 97.12158203125\n",
            "Epoch 21\tTop1 Train accuracy 64.4140625\tTop1 Test accuracy: 62.724609375\tTop5 test acc: 97.20703125\n",
            "Epoch 22\tTop1 Train accuracy 64.53125\tTop1 Test accuracy: 62.90771484375\tTop5 test acc: 97.255859375\n",
            "Epoch 23\tTop1 Train accuracy 64.6484375\tTop1 Test accuracy: 62.95654296875\tTop5 test acc: 97.29248046875\n",
            "Epoch 24\tTop1 Train accuracy 64.86328125\tTop1 Test accuracy: 63.12255859375\tTop5 test acc: 97.35595703125\n",
            "Epoch 25\tTop1 Train accuracy 65.1344223022461\tTop1 Test accuracy: 63.330078125\tTop5 test acc: 97.40478515625\n",
            "Epoch 26\tTop1 Train accuracy 65.3297348022461\tTop1 Test accuracy: 63.3984375\tTop5 test acc: 97.44873046875\n",
            "Epoch 27\tTop1 Train accuracy 65.4469223022461\tTop1 Test accuracy: 63.34228515625\tTop5 test acc: 97.412109375\n",
            "Epoch 28\tTop1 Train accuracy 65.6227035522461\tTop1 Test accuracy: 63.48876953125\tTop5 test acc: 97.412109375\n",
            "Epoch 29\tTop1 Train accuracy 65.85478210449219\tTop1 Test accuracy: 63.56201171875\tTop5 test acc: 97.42431640625\n",
            "Epoch 30\tTop1 Train accuracy 66.06732940673828\tTop1 Test accuracy: 63.67431640625\tTop5 test acc: 97.4560546875\n",
            "Epoch 31\tTop1 Train accuracy 66.20404815673828\tTop1 Test accuracy: 63.80859375\tTop5 test acc: 97.48046875\n",
            "Epoch 32\tTop1 Train accuracy 66.24080657958984\tTop1 Test accuracy: 63.92578125\tTop5 test acc: 97.5048828125\n",
            "Epoch 33\tTop1 Train accuracy 66.58777618408203\tTop1 Test accuracy: 63.9990234375\tTop5 test acc: 97.529296875\n",
            "Epoch 34\tTop1 Train accuracy 66.70496368408203\tTop1 Test accuracy: 64.1455078125\tTop5 test acc: 97.51708984375\n",
            "Epoch 35\tTop1 Train accuracy 66.80261993408203\tTop1 Test accuracy: 64.20654296875\tTop5 test acc: 97.529296875\n",
            "Epoch 36\tTop1 Train accuracy 66.91980743408203\tTop1 Test accuracy: 64.32861328125\tTop5 test acc: 97.51708984375\n",
            "Epoch 37\tTop1 Train accuracy 66.93933868408203\tTop1 Test accuracy: 64.3896484375\tTop5 test acc: 97.51708984375\n",
            "Epoch 38\tTop1 Train accuracy 66.97840118408203\tTop1 Test accuracy: 64.47021484375\tTop5 test acc: 97.529296875\n",
            "Epoch 39\tTop1 Train accuracy 67.11282348632812\tTop1 Test accuracy: 64.53125\tTop5 test acc: 97.56591796875\n",
            "Epoch 40\tTop1 Train accuracy 67.24954223632812\tTop1 Test accuracy: 64.6044921875\tTop5 test acc: 97.6025390625\n",
            "Epoch 41\tTop1 Train accuracy 67.34949493408203\tTop1 Test accuracy: 64.62890625\tTop5 test acc: 97.59033203125\n",
            "Epoch 42\tTop1 Train accuracy 67.42761993408203\tTop1 Test accuracy: 64.7265625\tTop5 test acc: 97.6025390625\n",
            "Epoch 43\tTop1 Train accuracy 67.52527618408203\tTop1 Test accuracy: 64.84375\tTop5 test acc: 97.61474609375\n",
            "Epoch 44\tTop1 Train accuracy 67.58386993408203\tTop1 Test accuracy: 64.87548828125\tTop5 test acc: 97.61474609375\n",
            "Epoch 45\tTop1 Train accuracy 67.64246368408203\tTop1 Test accuracy: 64.9365234375\tTop5 test acc: 97.626953125\n",
            "Epoch 46\tTop1 Train accuracy 67.75735473632812\tTop1 Test accuracy: 65.0341796875\tTop5 test acc: 97.66357421875\n",
            "Epoch 47\tTop1 Train accuracy 67.85501098632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.7001953125\n",
            "Epoch 48\tTop1 Train accuracy 67.89407348632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.73681640625\n",
            "Epoch 49\tTop1 Train accuracy 67.95266723632812\tTop1 Test accuracy: 65.15625\tTop5 test acc: 97.73681640625\n",
            "Epoch 50\tTop1 Train accuracy 68.01126098632812\tTop1 Test accuracy: 65.21728515625\tTop5 test acc: 97.76123046875\n",
            "Epoch 51\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.29052734375\tTop5 test acc: 97.7490234375\n",
            "Epoch 52\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.3564453125\tTop5 test acc: 97.78564453125\n",
            "Epoch 53\tTop1 Train accuracy 68.20657348632812\tTop1 Test accuracy: 65.3759765625\tTop5 test acc: 97.7978515625\n",
            "Epoch 54\tTop1 Train accuracy 68.28469848632812\tTop1 Test accuracy: 65.45654296875\tTop5 test acc: 97.822265625\n",
            "Epoch 55\tTop1 Train accuracy 68.41912078857422\tTop1 Test accuracy: 65.46875\tTop5 test acc: 97.8466796875\n",
            "Epoch 56\tTop1 Train accuracy 68.45818328857422\tTop1 Test accuracy: 65.5615234375\tTop5 test acc: 97.85888671875\n",
            "Epoch 57\tTop1 Train accuracy 68.61443328857422\tTop1 Test accuracy: 65.56640625\tTop5 test acc: 97.87109375\n",
            "Epoch 58\tTop1 Train accuracy 68.71208953857422\tTop1 Test accuracy: 65.5859375\tTop5 test acc: 97.90771484375\n",
            "Epoch 59\tTop1 Train accuracy 68.69255828857422\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.919921875\n",
            "Epoch 60\tTop1 Train accuracy 68.80744934082031\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.93212890625\n",
            "Epoch 61\tTop1 Train accuracy 68.94416809082031\tTop1 Test accuracy: 65.72021484375\tTop5 test acc: 97.93212890625\n",
            "Epoch 62\tTop1 Train accuracy 69.04182434082031\tTop1 Test accuracy: 65.76904296875\tTop5 test acc: 97.919921875\n",
            "Epoch 63\tTop1 Train accuracy 69.06135559082031\tTop1 Test accuracy: 65.84228515625\tTop5 test acc: 97.90771484375\n",
            "Epoch 64\tTop1 Train accuracy 69.19807434082031\tTop1 Test accuracy: 65.93505859375\tTop5 test acc: 97.90771484375\n",
            "Epoch 65\tTop1 Train accuracy 69.23713684082031\tTop1 Test accuracy: 65.95947265625\tTop5 test acc: 97.9150390625\n",
            "Epoch 66\tTop1 Train accuracy 69.25666809082031\tTop1 Test accuracy: 66.0888671875\tTop5 test acc: 97.939453125\n",
            "Epoch 67\tTop1 Train accuracy 69.31526184082031\tTop1 Test accuracy: 66.02783203125\tTop5 test acc: 97.939453125\n",
            "Epoch 68\tTop1 Train accuracy 69.43014526367188\tTop1 Test accuracy: 66.07666015625\tTop5 test acc: 97.9638671875\n",
            "Epoch 69\tTop1 Train accuracy 69.48873901367188\tTop1 Test accuracy: 66.12060546875\tTop5 test acc: 97.9638671875\n",
            "Epoch 70\tTop1 Train accuracy 69.50827026367188\tTop1 Test accuracy: 66.083984375\tTop5 test acc: 97.95166015625\n",
            "Epoch 71\tTop1 Train accuracy 69.60592651367188\tTop1 Test accuracy: 66.1572265625\tTop5 test acc: 97.9638671875\n",
            "Epoch 72\tTop1 Train accuracy 69.68635559082031\tTop1 Test accuracy: 66.2060546875\tTop5 test acc: 97.95166015625\n",
            "Epoch 73\tTop1 Train accuracy 69.78170776367188\tTop1 Test accuracy: 66.2744140625\tTop5 test acc: 97.92724609375\n",
            "Epoch 74\tTop1 Train accuracy 69.84030151367188\tTop1 Test accuracy: 66.31591796875\tTop5 test acc: 97.92724609375\n",
            "Epoch 75\tTop1 Train accuracy 69.89889526367188\tTop1 Test accuracy: 66.328125\tTop5 test acc: 97.9150390625\n",
            "Epoch 76\tTop1 Train accuracy 69.93795776367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.92724609375\n",
            "Epoch 77\tTop1 Train accuracy 69.95748901367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.9150390625\n",
            "Epoch 78\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.474609375\tTop5 test acc: 97.9150390625\n",
            "Epoch 79\tTop1 Train accuracy 69.99655151367188\tTop1 Test accuracy: 66.53564453125\tTop5 test acc: 97.939453125\n",
            "Epoch 80\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.56005859375\tTop5 test acc: 97.939453125\n",
            "Epoch 81\tTop1 Train accuracy 70.09420776367188\tTop1 Test accuracy: 66.56494140625\tTop5 test acc: 97.939453125\n",
            "Epoch 82\tTop1 Train accuracy 70.11373901367188\tTop1 Test accuracy: 66.650390625\tTop5 test acc: 97.939453125\n",
            "Epoch 83\tTop1 Train accuracy 70.19186401367188\tTop1 Test accuracy: 66.71142578125\tTop5 test acc: 97.92724609375\n",
            "Epoch 84\tTop1 Train accuracy 70.26998901367188\tTop1 Test accuracy: 66.7236328125\tTop5 test acc: 97.90283203125\n",
            "Epoch 85\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.73583984375\tTop5 test acc: 97.90283203125\n",
            "Epoch 86\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.748046875\tTop5 test acc: 97.890625\n",
            "Epoch 87\tTop1 Train accuracy 70.46530151367188\tTop1 Test accuracy: 66.7724609375\tTop5 test acc: 97.890625\n",
            "Epoch 88\tTop1 Train accuracy 70.52389526367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.90283203125\n",
            "Epoch 89\tTop1 Train accuracy 70.56295776367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.890625\n",
            "Epoch 90\tTop1 Train accuracy 70.68014526367188\tTop1 Test accuracy: 66.83349609375\tTop5 test acc: 97.87841796875\n",
            "Epoch 91\tTop1 Train accuracy 70.77780151367188\tTop1 Test accuracy: 66.826171875\tTop5 test acc: 97.87841796875\n",
            "Epoch 92\tTop1 Train accuracy 70.81686401367188\tTop1 Test accuracy: 66.88720703125\tTop5 test acc: 97.87841796875\n",
            "Epoch 93\tTop1 Train accuracy 70.85592651367188\tTop1 Test accuracy: 66.8994140625\tTop5 test acc: 97.87841796875\n",
            "Epoch 94\tTop1 Train accuracy 70.91452026367188\tTop1 Test accuracy: 66.9482421875\tTop5 test acc: 97.890625\n",
            "Epoch 95\tTop1 Train accuracy 71.03170776367188\tTop1 Test accuracy: 66.98486328125\tTop5 test acc: 97.890625\n",
            "Epoch 96\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.001953125\tTop5 test acc: 97.91015625\n",
            "Epoch 97\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.0263671875\tTop5 test acc: 97.91015625\n",
            "Epoch 98\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.06298828125\tTop5 test acc: 97.89794921875\n",
            "Epoch 99\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.0751953125\tTop5 test acc: 97.8857421875\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dtYqHZirMNZk"
      },
      "source": [
        ""
      ],
      "execution_count": 27,
      "outputs": []
    }
  ]
}